package check

import (
	"path/filepath"

	"gitee.com/u-language/u-language/ucom/ast"
	"gitee.com/u-language/u-language/ucom/data"
	"gitee.com/u-language/u-language/ucom/enum"
	"gitee.com/u-language/u-language/ucom/errcode"
	"gitee.com/u-language/u-language/ucom/internal/errutil"
	"gitee.com/u-language/u-language/ucom/internal/utils"
)

// Deferfunc exposes utils.Deferfunc from this package so callers of package
// check can use it without importing internal/utils.
// NOTE(review): its behavior is defined in internal/utils; judging by the name
// it is presumably meant to be used with defer — confirm.
var Deferfunc func() = utils.Deferfunc

// CheckTree checks a single ast tree.
//   - tree is the tree to check; a nil tree is ignored.
//   - errctx is the error-handling context that receives every diagnostic.
//   - IsInPackage reports whether the tree is being checked as part of a
//     package. When false, the struct, import, generic-instantiation and
//     main-function checks are also run against the tree's own tables.
//
// A package name other than "main" or "" must equal the name of the directory
// that contains the file; otherwise errcode.PackageNameNoEqualDirName is
// reported on line 1.
func CheckTree(tree *ast.Tree, errctx *errcode.ErrCtx, IsInPackage bool) {
	if tree == nil {
		return
	}
	if pkg := tree.PackageName; pkg != "main" && pkg != "" {
		// Use the same directory name for the comparison and for the
		// message, so the diagnostic names the directory that was compared
		// rather than the file's base name.
		if dir := ret_dir_name(tree.Filename); pkg != dir {
			errctx.Panic(tree.Filename, 1, errcode.NewMsgPackageNameNoEqualDirName(pkg, dir), errcode.PackageNameNoEqualDirName)
		}
	}
	m := utils.MapSet.Get().(map[string]struct{})
	defer utils.PutMapSet(m)
	// Check the whole top-level node list; every diagnostic returned is
	// reported against this tree's file.
	ret := checkCodeBlock(0, len(tree.Nodes), tree.Sbt, make(symbolCheckTable), m, tree, nil, false, make([]gotoStmtInfo, 0), false, "", false, false, tree.Nodes, 0)
	for _, v := range ret {
		errctx.Panic(tree.Filename, v.Line, v.Msg, v.Code)
	}
	if !IsInPackage { // standalone file: run the package-wide checks here
		checkAllStruct(tree.Filename, tree.Sbt, errctx)
		checkAllImport(tree.Filename, tree.ImportPackage, errctx)
		check_Generic(tree, errctx, tree.GenInstNodes)
		checkMain(tree.Sbt, tree.PackageName, errctx, tree.Filename)
	}
}

// checkMain enforces the main-function rule for a package.
//   - A package named "main" must define a symbol called "main"; otherwise
//     errcode.MainPackageMustHaveMainFunc is reported against file (line -1).
//   - The opposite rule (a non-main package must not define main) is
//     enforced by the ast package, not here.
//
// Existence is probed with sbt.Have, which is expected to panic with an
// unknown-symbol error when "main" is absent. That panic is absorbed; any
// other panic is re-raised.
func checkMain(sbt *ast.Sbt, packageName string, errctx *errcode.ErrCtx, file string) {
	hasMain := func() (found bool) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			if _, unknown := errutil.IsUnknownSymbol(r); !unknown {
				panic(r)
			}
			// Recovered unknown-symbol panic: found stays false.
		}()
		_ = sbt.Have("main")
		return true
	}()
	if packageName != "main" || hasMain {
		return
	}
	errctx.Panic(file, -1, nil, errcode.MainPackageMustHaveMainFunc)
}

// ret_dir_name returns the directory name referred to by path s.
// If the last element of s has an extension, it is taken to be a file name
// and its parent directory is used instead; the chosen path is then passed
// through utils.NoExtBase (presumably its base name without extension).
// NOTE(review): a directory whose name contains a dot is treated as a file
// by this heuristic.
func ret_dir_name(s string) string {
	dir := s
	if ext := filepath.Ext(s); len(ext) > 0 {
		dir = filepath.Dir(s)
	}
	return utils.NoExtBase(dir)
}

// CheckPackage checks every syntax tree of package p, then runs the
// package-wide checks once against the package-level tables:
//   - struct checks (recursive types, duplicate field names),
//   - circular-import checks,
//   - generic-instantiation checks (using the first tree as context),
//   - the main-function rule.
//
// Package-wide diagnostics use the package name as the file name.
//   - Thread is forwarded to CheckTree as its IsInPackage argument.
//     NOTE(review): when Thread is false, CheckTree also runs the
//     package-wide checks per file; confirm that is intended.
func CheckPackage(p *ast.Package, errctx *errcode.ErrCtx, Thread bool) {
	for _, v := range p.Trees.Data {
		CheckTree(v, errctx, Thread)
	}
	name := *p.PackageName.Load()
	checkAllStruct(name, p.Sbt, errctx)
	checkAllImport(name, p.ImportPackage, errctx)
	// check_Generic needs a tree for its symbol-table context; indexing
	// Data[0] on a package without trees would panic.
	if len(p.Trees.Data) != 0 {
		check_Generic(p.Trees.Data[0], errctx, p.GenInstNodes)
	}
	checkMain(p.Sbt, name, errctx, name)
}

// checkAllStruct validates every struct symbol in sbt:
//  1. the struct must not be a recursive type (one that includes itself);
//  2. its field names must be unique.
//
// Struct declarations that still have type parameters (not yet instantiated)
// are skipped. Diagnostics are reported through errctx with packagename as
// the file name.
func checkAllStruct(packagename string, sbt *ast.Sbt, errctx *errcode.ErrCtx) {
	for _, sym := range findStruct(sbt) {
		structDecl := sym.Info.(*ast.StructDecl)
		if len(structDecl.TypeParame) > 0 {
			continue
		}
		if code, msg := data.CheckIncludingSelf(sym.Info.(data.IsItARecursiveType)); code != errcode.NoErr {
			errctx.Panic(packagename, -1, msg, code)
		}
		if code, msg, line := checkStructField(structDecl); code != errcode.NoErr {
			errctx.Panic(packagename, line, msg, code)
		}
	}
}

// checkAllImport reports circular imports among the packages in ImportPackage.
// Each imported package is passed to data.CheckIncludingSelf. A non-NoErr
// result is reported as errcode.CircularImportErr at line -1, with packagename
// as the file name. A nil ImportPackage means there is nothing to check.
// Packages are visited in map iteration order, which Go leaves unspecified.
func checkAllImport(packagename string, ImportPackage *map[string]*ast.Package, errctx *errcode.ErrCtx) {
	if ImportPackage != nil {
		for _, imported := range *ImportPackage {
			if code, msg := data.CheckIncludingSelf(imported); code != errcode.NoErr {
				errctx.Panic(packagename, -1, msg, errcode.CircularImportErr)
			}
		}
	}
}

// findStruct collects every symbol in sbt whose kind is enum.SymbolStruct,
// in the order sbt.Range visits them. It returns nil if there are none.
func findStruct(sbt *ast.Sbt) []ast.SymbolInfo {
	var structs []ast.SymbolInfo
	collect := func(_ string, info ast.SymbolInfo) bool {
		if info.Kind == enum.SymbolStruct {
			structs = append(structs, info)
		}
		return true // keep visiting
	}
	sbt.Range(collect)
	return structs
}

// checkStructField checks that the field names of struct v are unique.
// On the first repeated name it returns errcode.FieldDupName, a message
// naming that field, and v.LineNum (the struct's line, not the field's).
// Otherwise it returns errcode.NoErr, nil, 0.
// The set of seen names is borrowed from utils.MapSet and returned on exit.
func checkStructField(v *ast.StructDecl) (errcode.ErrCode, errcode.Msg, int) {
	seen := utils.MapSet.Get().(map[string]struct{})
	defer utils.PutMapSet(seen)
	for _, field := range v.FieldTable {
		name := field.Name
		if _, dup := seen[name]; dup {
			return errcode.FieldDupName, errcode.NewMsgSymbol(name), v.LineNum
		}
		seen[name] = struct{}{}
	}
	return errcode.NoErr, nil, 0
}

// check_recover is the panic-recovery routine for package check. It must be
// invoked directly via defer (recover only works there). It handles two kinds
// of panic:
//   - an unknown-symbol panic sets *code, *msg to errcode.UnknownSymbol and
//     a message naming the symbol;
//   - an unknown-type panic sets *code, *msg to errcode.UnknownType and a
//     message naming the type.
//
// Any other panic is re-raised unchanged; no panic leaves *code and *msg as-is.
func check_recover(code *errcode.ErrCode, msg *errcode.Msg) {
	r := recover()
	if r == nil {
		return
	}
	if name, isSym := errutil.IsUnknownSymbol(r); isSym {
		*code, *msg = errcode.UnknownSymbol, errcode.NewMsgUnknownSymbol(name)
		return
	}
	if name, isType := errutil.IsUnknownType(r); isType {
		*code, *msg = errcode.UnknownType, errcode.NewMsgUnknownSymbol(name)
		return
	}
	panic(r)
}

// check_Generic checks the code of every generic instantiation in GenInstNodes.
//   - tree supplies the symbol table (tree.Sbt) and tree context for
//     checkCodeBlock.
//   - errctx receives every diagnostic. Diagnostics are attributed to the file
//     recorded in each instantiation's ast.RemoveGenericsInfo.
//   - GenInstNodes holds ast.RemoveGenericsInfo values; nil means nothing to
//     check.
//
// Each instantiation is checked in isolation. An unknown-symbol or
// unknown-type panic raised while checking one instantiation is turned into a
// diagnostic by check_recover and does not affect the others. Any other panic
// propagates.
func check_Generic(tree *ast.Tree, errctx *errcode.ErrCtx, GenInstNodes *ast.Sbt) {
	if GenInstNodes == nil {
		return
	}
	GenInstNodes.Range(func(_ string, info ast.SymbolInfo) (cont bool) {
		// Preset the result. When a panic is recovered below, the callback
		// returns its result's current value. With an unnamed result that
		// would be false, which Range presumably treats as "stop", skipping
		// the remaining instantiations.
		cont = true
		value := info.Info.(ast.RemoveGenericsInfo)
		// Per-instantiation error state. Declaring these outside the callback
		// would make every later iteration re-report an earlier recovered
		// error.
		var code errcode.ErrCode = errcode.NoErr
		var msg errcode.Msg
		// Deferred calls run in reverse order: the map is returned to the pool,
		// check_recover turns a recoverable panic into code/msg, then the
		// recovered error (if any) is reported.
		defer func() {
			if code != errcode.NoErr {
				errctx.Panic(value.FileName, -1, msg, code)
			}
		}()
		defer check_recover(&code, &msg)
		nodes := value.Nodes
		m := utils.MapSet.Get().(map[string]struct{})
		defer utils.PutMapSet(m)
		// value.LineNum-1 is passed as the line offset of the instantiated
		// code. NOTE(review): presumably so diagnostics map back to the
		// original declaration — confirm.
		ret := checkCodeBlock(0, len(nodes), tree.Sbt, make(symbolCheckTable), m, tree, nil, false, make([]gotoStmtInfo, 0), false, "", false, true, nodes, value.LineNum-1)
		for _, v := range ret {
			errctx.Panic(value.FileName, v.Line, v.Msg, v.Code)
		}
		return true
	})
}
